{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "from torch import nn, optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleCNN(nn.Module):\n",
    "    \"\"\"Small CNN: three conv/ReLU/max-pool stages and a three-layer classifier.\n",
    "\n",
    "    The shape notes below assume a (b, 3, 32, 32) input batch; the first\n",
    "    linear layer expects 2048 = 128 * 4 * 4 flattened features.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Each stage keeps the spatial size in its conv and halves it in its pool.\n",
    "        self.layer1 = self._conv_stage(1, 3, 32)    # b, 32, 16, 16\n",
    "        self.layer2 = self._conv_stage(2, 32, 64)   # b, 64, 8, 8\n",
    "        self.layer3 = self._conv_stage(3, 64, 128)  # b, 128, 4, 4\n",
    "\n",
    "        classifier = nn.Sequential()\n",
    "        classifier.add_module('fc_1', nn.Linear(2048, 512))\n",
    "        classifier.add_module('fc_relu1', nn.ReLU(True))\n",
    "        classifier.add_module('fc_2', nn.Linear(512, 64))\n",
    "        classifier.add_module('fc_relu2', nn.ReLU(True))\n",
    "        classifier.add_module('fc_3', nn.Linear(64, 10))\n",
    "        self.layer4 = classifier\n",
    "\n",
    "    @staticmethod\n",
    "    def _conv_stage(index, in_channels, out_channels):\n",
    "        \"\"\"Return stage `index`: 3x3 conv (stride 1, padding 1), in-place ReLU, 2x2 max-pool.\"\"\"\n",
    "        stage = nn.Sequential()\n",
    "        stage.add_module('conv_{}'.format(index), nn.Conv2d(in_channels, out_channels, 3, 1, padding=1))\n",
    "        stage.add_module('relu_{}'.format(index), nn.ReLU(True))\n",
    "        stage.add_module('pool_{}'.format(index), nn.MaxPool2d(2, 2))\n",
    "        return stage\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the conv stages, flatten per sample, and return the 10 classifier outputs.\"\"\"\n",
    "        features = self.layer3(self.layer2(self.layer1(x)))\n",
    "        flattened = features.view(features.size(0), -1)\n",
    "        return self.layer4(flattened)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SimpleCNN(\n",
      "  (layer1): Sequential(\n",
      "    (conv_1): Conv2d (3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (relu_1): ReLU(inplace)\n",
      "    (pool_1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))\n",
      "  )\n",
      "  (layer2): Sequential(\n",
      "    (conv_2): Conv2d (32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (relu_2): ReLU(inplace)\n",
      "    (pool_2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))\n",
      "  )\n",
      "  (layer3): Sequential(\n",
      "    (conv_3): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (relu_3): ReLU(inplace)\n",
      "    (pool_3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))\n",
      "  )\n",
      "  (layer4): Sequential(\n",
      "    (fc_1): Linear(in_features=2048, out_features=512)\n",
      "    (fc_relu1): ReLU(inplace)\n",
      "    (fc_2): Linear(in_features=512, out_features=64)\n",
      "    (fc_relu2): ReLU(inplace)\n",
      "    (fc_3): Linear(in_features=64, out_features=10)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Instantiate the network and print its module hierarchy.\n",
    "model = SimpleCNN()\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Sequential(\n",
      "    (conv_1): Conv2d (3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (relu_1): ReLU(inplace)\n",
      "    (pool_1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))\n",
      "  )\n",
      "  (1): Sequential(\n",
      "    (conv_2): Conv2d (32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (relu_2): ReLU(inplace)\n",
      "    (pool_2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Keep only the first two top-level children of the model (layer1 and layer2).\n",
    "leading_blocks = list(model.children())[:2]\n",
    "new_model = nn.Sequential(*leading_blocks)\n",
    "print(new_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (layer1.conv_1): Conv2d (3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "  (layer2.conv_2): Conv2d (32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "  (layer3.conv_3): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Collect every convolution layer of the model into one flat Sequential.\n",
    "#\n",
    "# named_modules() yields dotted paths such as 'layer1.conv_1'. Current PyTorch\n",
    "# releases reject '.' in a child-module name (add_module raises KeyError), so\n",
    "# the dots are replaced with underscores before the layer is registered; the\n",
    "# printed keys therefore read 'layer1_conv_1', 'layer2_conv_2', ...\n",
    "conv_model = nn.Sequential()\n",
    "for name, module in model.named_modules():\n",
    "    if isinstance(module, nn.Conv2d):\n",
    "        conv_model.add_module(name.replace('.', '_'), module)\n",
    "\n",
    "print(conv_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "layer1.conv_1.weight : torch.Size([32, 3, 3, 3])\n",
      "layer1.conv_1.bias : torch.Size([32])\n",
      "layer2.conv_2.weight : torch.Size([64, 32, 3, 3])\n",
      "layer2.conv_2.bias : torch.Size([64])\n",
      "layer3.conv_3.weight : torch.Size([128, 64, 3, 3])\n",
      "layer3.conv_3.bias : torch.Size([128])\n",
      "layer4.fc_1.weight : torch.Size([512, 2048])\n",
      "layer4.fc_1.bias : torch.Size([512])\n",
      "layer4.fc_2.weight : torch.Size([64, 512])\n",
      "layer4.fc_2.bias : torch.Size([64])\n",
      "layer4.fc_3.weight : torch.Size([10, 64])\n",
      "layer4.fc_3.bias : torch.Size([10])\n"
     ]
    }
   ],
   "source": [
    "# List every learnable parameter of the model with its shape.\n",
    "for param_name, tensor in model.named_parameters():\n",
    "    print('{} : {}'.format(param_name, tensor.size()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weight initialisation\n",
    "from torch.nn import init\n",
    "\n",
    "# Only the last initialiser applied to a tensor takes effect, so each layer\n",
    "# type gets exactly one. The in-place functions (trailing underscore) are used\n",
    "# because init.normal / init.xavier_normal / init.kaiming_normal are deprecated.\n",
    "for m in model.modules():\n",
    "    if isinstance(m, nn.Conv2d):\n",
    "        # Kaiming (He) normal initialisation: every conv in this model feeds a ReLU.\n",
    "        # init.normal_ or init.xavier_normal_ can be swapped in here to compare.\n",
    "        init.kaiming_normal_(m.weight)\n",
    "        # A layer built with bias=False has no bias tensor to reset.\n",
    "        if m.bias is not None:\n",
    "            init.constant_(m.bias, 0)\n",
    "    elif isinstance(m, nn.Linear):\n",
    "        # Standard normal N(0, 1), the defaults of init.normal_.\n",
    "        init.normal_(m.weight)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "关于参数初始化可参考[深度学习的weight initialization](https://zhuanlan.zhihu.com/p/25110150)\n",
    "\n",
    "可以从torch的[文档](http://pytorch.org/docs/master/nn.html?highlight=init%20xavier_normal#torch.nn.init.xavier_normal)中得到\n",
    "\n",
    "- `init.xavier_uniform()`一般用于tanh的初始化，结果采样于均匀分布 $$U(-a, a) \\sim [-\\frac {\\sqrt{6}} {\\sqrt{fan\\_in + fan\\_out}}, \\frac {\\sqrt{6}} {\\sqrt{fan\\_in + fan\\_out}}]$$\n",
    "- `init.xavier_normal()`，结果采样于正态分布 $$N(0, \\sqrt{\\frac 2 {fan\\_in + fan\\_out}})$$\n",
    "- `init.kaiming_uniform()` 结果采样于均匀分布（其中 $a$ 为 leaky ReLU 的负斜率，默认模式为 fan\\_in） $$U(-bound, bound) \\sim [-\\frac {\\sqrt{6}} {\\sqrt{(1+a^2) \\times fan\\_in}}, \\frac {\\sqrt{6}} {\\sqrt{(1+a^2) \\times fan\\_in}}]$$\n",
    "- `init.kaiming_normal()`一般用于ReLU的初始化，初始化方法为正态分布 $$N(0, \\sqrt{\\frac 2 {(1 + a^2) \\times fan\\_in}})$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
